{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tensorflow_compression as tfc\n",
    "import tensorflow_probability as tfp\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from time import time\n",
    "from tqdm import tqdm\n",
    "\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class H_a(tf.keras.layers.Layer):\n",
    "    \"\"\"Encoder network for the hyperprior.\n",
    "\n",
    "    Takes the element-wise absolute value of its input and passes it through\n",
    "    three convolutions with N filters each: a 3x3 stride-1 layer and a 5x5\n",
    "    stride-2 layer, both ReLU-activated, followed by a 5x5 stride-2 layer\n",
    "    without activation.\n",
    "\n",
    "    Args:\n",
    "        N: Number of filters used by every convolution.\n",
    "        **kwargs: Standard Keras layer arguments (e.g. name, dtype),\n",
    "            forwarded unchanged to the base Layer constructor.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, N, **kwargs):\n",
    "        \"\"\"Initializes the encoder and creates its three convolutions.\"\"\"\n",
    "        super().__init__(**kwargs)\n",
    "        self.N = N\n",
    "\n",
    "        # No padding argument is passed, so the Conv2D default is used.\n",
    "        self.conv1 = tf.keras.layers.Conv2D(self.N, 3, strides=1, activation='relu')\n",
    "        self.conv2 = tf.keras.layers.Conv2D(self.N, 5, strides=2, activation='relu')\n",
    "        # The last convolution has no activation.\n",
    "        self.conv3 = tf.keras.layers.Conv2D(self.N, 5, strides=2)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Forward pass of the encoder.\n",
    "\n",
    "        Args:\n",
    "            inputs: Tensor to encode; tf.abs is applied to it first.\n",
    "\n",
    "        Returns:\n",
    "            The output of the third convolution.\n",
    "        \"\"\"\n",
    "        # Only magnitudes reach the convolutions; the sign is discarded here.\n",
    "        x = tf.abs(inputs)\n",
    "        x = self.conv1(x)\n",
    "        x = self.conv2(x)\n",
    "        return self.conv3(x)\n",
    "\n",
    "    def get_config(self):\n",
    "        \"\"\"Returns the constructor arguments so the layer can be re-created.\"\"\"\n",
    "        config = super().get_config()\n",
    "        config.update({'N': self.N})\n",
    "        return config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class H_s(tf.keras.layers.Layer):\n",
    "    \"\"\"Decoder network for the hyperprior.\n",
    "\n",
    "    Passes its input through three ReLU-activated transposed convolutions:\n",
    "    two 5x5 stride-2 layers with N filters, then a 3x3 stride-1 layer with\n",
    "    M filters.\n",
    "\n",
    "    Args:\n",
    "        N: Number of filters in the first two transposed convolutions.\n",
    "        M: Number of filters in the last transposed convolution.\n",
    "        **kwargs: Standard Keras layer arguments (e.g. name, dtype),\n",
    "            forwarded unchanged to the base Layer constructor.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, N, M, **kwargs):\n",
    "        \"\"\"Initializes the decoder and creates its three transposed convolutions.\"\"\"\n",
    "        super().__init__(**kwargs)\n",
    "        self.N = N\n",
    "        self.M = M\n",
    "\n",
    "        # No padding argument is passed, so the Conv2DTranspose default is used.\n",
    "        self.conv1 = tf.keras.layers.Conv2DTranspose(self.N, 5, strides=2, activation='relu')\n",
    "        self.conv2 = tf.keras.layers.Conv2DTranspose(self.N, 5, strides=2, activation='relu')\n",
    "        # The output layer is ReLU-activated as well, so its values are non-negative.\n",
    "        self.conv3 = tf.keras.layers.Conv2DTranspose(self.M, 3, strides=1, activation='relu')\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Forward pass of the decoder.\n",
    "\n",
    "        Args:\n",
    "            inputs: Tensor to decode.\n",
    "\n",
    "        Returns:\n",
    "            The output of the third transposed convolution.\n",
    "        \"\"\"\n",
    "        x = self.conv1(inputs)\n",
    "        x = self.conv2(x)\n",
    "        return self.conv3(x)\n",
    "\n",
    "    def get_config(self):\n",
    "        \"\"\"Returns the constructor arguments so the layer can be re-created.\"\"\"\n",
    "        config = super().get_config()\n",
    "        config.update({'N': self.N, 'M': self.M})\n",
    "        return config"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.9"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "220aeacb52331f97f7780d84158167722a7cc0056b78b1c27b32a8f95243ca77"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
